package nsalt

import chisel3._
import chisel3.util._

import nsalt._
import nsalt.bus._
import nsalt.decode._
import nsalt.func._

// Found at:
// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/backend/seq/EXU.scala
// 
// Docs:
// https://oscpu.gitbook.io/nutshell/liu-shui-xian-she-ji-xi-jie/exu

/** Execute stage.
  *
  * Takes one decoded instruction from `io.in`, activates the functional unit
  * selected by `ctrlSgnl.funcType` (ALU, LSU, MDU, CSR or MOU) and presents the
  * per-unit results on `io.out`. It also drives the forwarding port and passes
  * the data-memory and MMU ports through to the LSU and CSR units respectively.
  *
  * All functional-unit outputs are always accepted (`out.ready := true.B`);
  * back-pressure from `io.out.ready` is only reflected through `io.in.ready`.
  */
class Execute extends Module with Config with Exception {
  
  val io = IO(new Bundle {
    // Decoded instruction coming from the previous stage.
    val in = Flipped(Decoupled(new DecodePort))
    // When asserted, no functional unit is marked valid (see funcValid) and
    // the CSR unit's instrValid is deasserted.
    val flush = Input(Bool())
    // Data-memory bus; connected directly to the LSU.
    val dmem = new BusUncached(addrBits = VIRT_MEM_ADDR_LEN)
    // Forwarding information driven at the bottom of this class.
    val forward = new ForwardPort
    // Instruction/data MMU ports; connected to the CSR unit.
    val memMMU = Flipped(new MemUniPort)
  
    // Results handed to the next stage.
    val out = Decoupled(new CommitPort)
  })

  // Operands, truncated to the low XLEN bits.
  val src1 = io.in.bits.data.src1(XLEN - 1, 0)
  val src2 = io.in.bits.data.src2(XLEN - 1, 0)

  val funcType = io.in.bits.ctrlSgnl.funcType
  val oper = io.in.bits.ctrlSgnl.operType

  // Per-unit valid: entry i is asserted only when the instruction targets
  // function type i, the input is valid and no flush is in progress.
  // At most one entry is asserted at a time.
  val funcValid = Wire(Vec(FuncType.num, Bool()))
  for (i <- 0 until FuncType.num) {
    funcValid(i) := (funcType === i.U) && io.in.valid && !io.flush
  }

  // ---- ALU ----
  // NOTE(review): `access` is defined on the unit, not here; presumably it
  // drives the unit's inputs and returns its result value - confirm in ArithLogic.
  val alu = Module(new ArithLogic)
  val aluOut = alu.access(
    valid = funcValid(FuncType.alu), 
    src1 = src1, 
    src2 = src2, 
    oper = oper
  )

  alu.io.alu.ctrlFlow   := io.in.bits.ctrlFlow
  alu.io.alu.offset     := io.in.bits.data.imm
  // Functional-unit output port: the result is always accepted.
  alu.io.alu.out.ready  := true.B

  // Returns bit 4 of `func`. Not referenced anywhere else in this class.
  // NOTE(review): the name suggests it flags branch operations - confirm
  // against the operation encoding.
  def isBru(func: UInt) = func(4)

  // ---- LSU ----
  val lsu = Module(new LoadStore)
  // Defaults to false and has no other driver in this class.
  // NOTE(review): presumably driven inside lsu.access via the dtlbPageFault
  // argument - confirm in LoadStore.
  val lsuTLBPageFault = WireInit(false.B)
  
  // The immediate (not the src2 register operand) is passed as the second
  // operand; the register operand src2 is supplied separately as write data.
  val lsuOut = lsu.access(
    valid = funcValid(FuncType.lsu),
    src1 = src1,
    src2 = io.in.bits.data.imm,
    oper = oper,
    dtlbPageFault = lsuTLBPageFault
  )

  lsu.io.dataW := src2
  lsu.io.instr := io.in.bits.ctrlFlow.instr
  // Flagged as MMIO when the LSU reports an MMIO access, or when the
  // instruction's own PC lies in MMIO space and the output is valid.
  io.out.bits.isMMIO := lsu.io.isMMIO || (AddressSpace.isMMIO(io.in.bits.ctrlFlow.pc) && io.out.valid)
  io.dmem <> lsu.io.dmem
  lsu.io.out.ready := true.B

  // ---- MDU ----
  val mdu = Module(new MulDiv)
  val mduOut = mdu.access(
    valid = funcValid(FuncType.mdu), 
    src1 = src1, 
    src2 = src2, 
    oper = oper
  )
  mdu.io.out.ready := true.B

  // ---- CSR ----
  // val csr = if (Settings.get("MmodeOnly")) Module(new CSR_M) else Module(new CSR)
  val csr = Module(new CtrlStatusReg)
  val csrOut = csr.access(
    valid = funcValid(FuncType.csr),
    src1 = src1, 
    src2 = src2, 
    oper = oper
  )

  // Bulk-connect the control flow first, then override the two misalignment
  // entries of the exception vector with the LSU's flags. The order matters:
  // the later connections take precedence over the bulk connection.
  csr.io.ctrlFlow := io.in.bits.ctrlFlow
  csr.io.ctrlFlow.exceptionVec(LOAD_ADDR_MISALIGNED)  := lsu.io.loadAddrMisaligned
  csr.io.ctrlFlow.exceptionVec(STORE_ADDR_MISALIGNED) := lsu.io.storeAddrMisaligned
  csr.io.instrValid := io.in.valid && !io.flush

  // io.out.bits.intrNO := csr.io.intrNO
  // Backend exceptions are never signalled from this stage.
  csr.io.isBackendException := false.B
  csr.io.out.ready := true.B

  csr.io.imemMMU <> io.memMMU.imem
  csr.io.dmemMMU <> io.memMMU.dmem

  // ---- MOU ----
  val mou = Module(new MemOrder)
  // The MOU does not write a register, so its access result is discarded.
  mou.access(
    valid = funcValid(FuncType.mou),
    src1 = src1,
    src2 = src2, 
    oper = oper
  )
  mou.io.ctrlFlow  := io.in.bits.ctrlFlow
  mou.io.out.ready := true.B

  // Either kind of LSU address misalignment.
  val lsuMisaligned = lsu.io.loadAddrMisaligned || lsu.io.storeAddrMisaligned
  
  // Some cleanup was done here compared to
  // https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/backend/seq/EXU.scala#L88-L92
  // Everything in `decode` is DontCare unless explicitly assigned below.
  io.out.bits.decode := DontCare
  (io.out.bits.decode.ctrlSgnl, io.in.bits.ctrlSgnl) match { case (o, i) =>
    // Suppress the register write when an active LSU operation page-faults or
    // is misaligned, or when an active CSR operation asserts wenFix.
    o.regEnableW := i.regEnableW && 
      (!funcValid(FuncType.lsu) || !(lsuTLBPageFault || lsuMisaligned) ) && 
      (!funcValid(FuncType.csr) || !csr.io.wenFix)
    o.destRef  := i.destRef
    o.funcType := i.funcType
  }

  io.out.bits.decode.ctrlFlow.pc := io.in.bits.ctrlFlow.pc
  io.out.bits.decode.ctrlFlow.instr  := io.in.bits.ctrlFlow.instr
  // Redirect priority: MOU first, then CSR, otherwise the ALU's redirect.
  io.out.bits.decode.ctrlFlow.redirect <> Mux(mou.io.redirect.valid, 
    mou.io.redirect,
    Mux(csr.io.redirect.valid,
      csr.io.redirect,
      alu.io.alu.redirect
    )
  )
  
  // FIXME: should handle io.out.ready == false
  // LSU and MDU instructions are valid only once their unit reports a valid
  // result; every other function type is valid as soon as the input is.
  io.out.valid := io.in.valid && MuxLookup(funcType, true.B, List(
    FuncType.lsu -> lsu.io.out.valid,
    FuncType.mdu -> mdu.io.out.valid
  ))

  // One commit slot per function type, indexed by FuncType; see the
  // definition of CommitPort. The MOU slot is tied to zero.
  io.out.bits.commits(FuncType.alu) := aluOut
  io.out.bits.commits(FuncType.lsu) := lsuOut
  io.out.bits.commits(FuncType.csr) := csrOut
  io.out.bits.commits(FuncType.mdu) := mduOut
  io.out.bits.commits(FuncType.mou) := 0.U

  // Accept a new instruction when the stage is empty or the current one
  // leaves through io.out in this cycle.
  io.in.ready := !io.in.valid || io.out.fire

  // This part will be connected to the Issue stage for further regfile updates.
  // The forwarded data is the ALU result when the ALU output fires, and the
  // LSU result otherwise.
  // NOTE(review): MDU and CSR results are never selected here - confirm that
  // consumers use `funcType` to decide whether the forwarded data is usable.
  io.forward.valid            := io.in.valid
  io.forward.writeReg.enableW := io.in.bits.ctrlSgnl.regEnableW
  io.forward.writeReg.dest    := io.in.bits.ctrlSgnl.destRef
  io.forward.writeReg.data    := Mux(alu.io.alu.out.fire, aluOut, lsuOut)
  io.forward.funcType         := io.in.bits.ctrlSgnl.funcType

}
